/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.sql.calcite.aggregation.builtin;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.cardinality.CardinalityAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniqueFinalizingPostAggregator;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import org.apache.druid.sql.calcite.table.RowSignatures;

import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * {@link SqlAggregator} for the SQL function named by {@link #NAME}.
 *
 * The single argument is planned in one of three ways:
 * <ul>
 *   <li>a direct reference to a column whose complex type passes {@link #isValidComplexInputType} is aggregated
 *   with a {@link HyperUniquesAggregatorFactory} reading that column;</li>
 *   <li>any other argument whose Druid type is complex and passes the same check is aggregated with a
 *   {@link HyperUniquesAggregatorFactory} reading the output of a dimension spec built for the argument;</li>
 *   <li>every non-complex argument is aggregated with a {@link CardinalityAggregatorFactory}.</li>
 * </ul>
 */
public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator
{
  public static final String NAME = "APPROX_COUNT_DISTINCT_BUILTIN";

  private static final SqlAggFunction FUNCTION_INSTANCE = new BuiltinApproxCountDistinctSqlAggFunction();

  @Override
  public SqlAggFunction calciteFunction()
  {
    return FUNCTION_INSTANCE;
  }

  /**
   * Translates the aggregate call into a Druid {@link Aggregation}.
   *
   * When {@code finalizeAggregations} is set, the aggregator factory is given a prefixed name and a
   * {@link HyperUniqueFinalizingPostAggregator} carrying {@code name} is attached; otherwise the factory itself is
   * called {@code name} and no post-aggregator is added.
   *
   * @return the aggregation, or null when the argument cannot be converted to a {@link DruidExpression} or when its
   * complex type is not supported (in the latter case a planning error is recorded on {@code plannerContext})
   *
   * @throws ISE if the argument's SQL type has no corresponding Druid column type
   */
  @Nullable
  @Override
  public Aggregation toDruidAggregation(
      final PlannerContext plannerContext,
      final VirtualColumnRegistry virtualColumnRegistry,
      final String name,
      final AggregateCall aggregateCall,
      final InputAccessor inputAccessor,
      final List<Aggregation> existingAggregations,
      final boolean finalizeAggregations
  )
  {
    // Aggregations.getArgumentsForSimpleAggregator is deliberately avoided here: it would prevent direct column
    // access for string columns.
    final RexNode rexNode = inputAccessor.getField(Iterables.getOnlyElement(aggregateCall.getArgList()));

    final DruidExpression arg = Expressions.toDruidExpression(
        plannerContext,
        inputAccessor.getInputRowSignature(),
        rexNode
    );
    if (arg == null) {
      return null;
    }

    final String aggregatorName;
    if (finalizeAggregations) {
      aggregatorName = Calcites.makePrefixedName(name, "a");
    } else {
      aggregatorName = name;
    }

    final AggregatorFactory aggregatorFactory;
    if (isDirectAccessToValidComplexColumn(arg, inputAccessor)) {
      aggregatorFactory = new HyperUniquesAggregatorFactory(aggregatorName, arg.getDirectColumn(), false, true);
    } else {
      aggregatorFactory = createFactoryForExpression(
          plannerContext,
          virtualColumnRegistry,
          rexNode,
          arg,
          aggregatorName
      );
      if (aggregatorFactory == null) {
        // Unsupported complex type; the planning error has already been recorded.
        return null;
      }
    }

    HyperUniqueFinalizingPostAggregator finalizer = null;
    if (finalizeAggregations) {
      finalizer = new HyperUniqueFinalizingPostAggregator(name, aggregatorFactory.getName());
    }

    return Aggregation.create(Collections.singletonList(aggregatorFactory), finalizer);
  }

  /**
   * Returns true when {@code arg} is a plain column reference and the input row signature reports a type for that
   * column which passes {@link #isValidComplexInputType}. A column with no known type yields false.
   */
  private boolean isDirectAccessToValidComplexColumn(final DruidExpression arg, final InputAccessor inputAccessor)
  {
    if (!arg.isDirectColumnAccess()) {
      return false;
    }

    return inputAccessor.getInputRowSignature()
                        .getColumnType(arg.getDirectColumn())
                        .map(columnType -> isValidComplexInputType(columnType))
                        .orElse(false);
  }

  /**
   * Builds the aggregator factory for an argument that did not qualify for the direct-column path.
   *
   * The dimension spec (and, for arguments that are not simple extractions, the backing virtual column) is created
   * before the complex type is validated.
   *
   * @return the factory, or null after recording a planning error when the argument has an unsupported complex type
   *
   * @throws ISE if the type of {@code rexNode} cannot be mapped to a Druid column type
   */
  @Nullable
  private AggregatorFactory createFactoryForExpression(
      final PlannerContext plannerContext,
      final VirtualColumnRegistry virtualColumnRegistry,
      final RexNode rexNode,
      final DruidExpression arg,
      final String aggregatorName
  )
  {
    final RelDataType relType = rexNode.getType();
    final ColumnType druidType = Calcites.getColumnTypeForRelDataType(relType);
    if (druidType == null) {
      throw new ISE(
          "Cannot translate sqlTypeName[%s] to Druid type for field[%s]",
          relType.getSqlTypeName(),
          aggregatorName
      );
    }

    final DimensionSpec dimensionSpec = arg.isSimpleExtraction()
        ? arg.getSimpleExtraction().toDimensionSpec(null, druidType)
        : new DefaultDimensionSpec(
            virtualColumnRegistry.getOrCreateVirtualColumnForExpression(arg, relType),
            null,
            druidType
        );

    if (!druidType.is(ValueType.COMPLEX)) {
      return new CardinalityAggregatorFactory(
          aggregatorName,
          null,
          ImmutableList.of(dimensionSpec),
          false,
          true
      );
    }

    if (isValidComplexInputType(druidType)) {
      return new HyperUniquesAggregatorFactory(
          aggregatorName,
          dimensionSpec.getOutputName(),
          false,
          true
      );
    }

    plannerContext.setPlanningError(
        "Using APPROX_COUNT_DISTINCT() or enabling approximation with COUNT(DISTINCT) is not supported for"
        + " column type [%s]. You can disable approximation by setting [%s: false] in the query context.",
        arg.getDruidType(),
        PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT
    );
    return null;
  }

  /**
   * Calcite definition of the function: its operand may be a string, a numeric, or the complex type of
   * {@link HyperUniquesAggregatorFactory#TYPE}, and its declared return type is BIGINT.
   */
  private static class BuiltinApproxCountDistinctSqlAggFunction extends SqlAggFunction
  {
    BuiltinApproxCountDistinctSqlAggFunction()
    {
      super(
          NAME,
          null,
          SqlKind.OTHER_FUNCTION,
          ReturnTypes.explicit(SqlTypeName.BIGINT),
          InferTypes.VARCHAR_1024,
          OperandTypes.or(
              OperandTypes.STRING,
              OperandTypes.NUMERIC,
              RowSignatures.complexTypeChecker(HyperUniquesAggregatorFactory.TYPE)
          ),
          SqlFunctionCategory.STRING,
          false,
          false,
          Optionality.FORBIDDEN
      );
    }
  }

  /**
   * Returns true when the complex type name of {@code columnType} equals that of either
   * {@link HyperUniquesAggregatorFactory#TYPE} or {@link HyperUniquesAggregatorFactory#PRECOMPUTED_TYPE}.
   */
  private boolean isValidComplexInputType(ColumnType columnType)
  {
    final String complexTypeName = columnType.getComplexTypeName();
    if (Objects.equals(complexTypeName, HyperUniquesAggregatorFactory.TYPE.getComplexTypeName())) {
      return true;
    }
    return Objects.equals(complexTypeName, HyperUniquesAggregatorFactory.PRECOMPUTED_TYPE.getComplexTypeName());
  }
}
